# --- SCRIPT: US MODEL v2 (v17 - Corrected 3x3 Baskets) ---
import pandas as pd
import yfinance as yf
import pandas_datareader.data as web
from sklearn.preprocessing import MinMaxScaler
import datetime
import statsmodels.api as sm
import matplotlib.pyplot as plt
from statsmodels.tsa.api import VAR
from statsmodels.tsa.stattools import adfuller

print("--- STARTING SCRIPT: US MODEL v2 (3x3 Baskets) ---", flush=True)
ECON_REGION = "US"

# --- 1. Fetch Economic Data (US) ---
# Weekly VIX / S&P 500 closes come from Yahoo Finance and are averaged to
# month-end. Quarterly NBER recession flags (USRECQ) and real GDP (GDPC1)
# come from FRED. Any failure here is fatal: the script exits with status 1.
print(f"Fetching {ECON_REGION} economic data...", flush=True)
try:
    start_date = '2004-01-01'
    end_date = '2024-12-31'

    data_vix = yf.download('^VIX', start=start_date, end=end_date, interval='1wk')
    data_vix = data_vix[['Close']].rename(columns={'Close': 'VIX'})

    data_gspc = yf.download('^GSPC', start=start_date, end=end_date, interval='1wk')
    data_gspc = data_gspc[['Close']].rename(columns={'Close': 'SP500'})

    econ_df_w = pd.concat([data_vix, data_gspc], axis=1)
    # Newer yfinance versions return (Price, Ticker) MultiIndex columns. The
    # rename above relabels level 0 to 'VIX' / 'SP500', so keep only that level.
    if econ_df_w.columns.nlevels > 1:
        econ_df_w.columns = econ_df_w.columns.get_level_values(0)
    econ_df_w['SP500_Return'] = econ_df_w['SP500'].pct_change()
    # Aggregate econ weekly data to month-end so it aligns with the trends monthly index
    econ_df_w = econ_df_w.resample('M').mean()
    # Forward-fill small gaps so rows are not lost when later resampling to quarters
    econ_df_w = econ_df_w.ffill()
    print("...US VIX/S&P 500 fetched.", flush=True)

    # Query FRED over exactly the same window as the Yahoo downloads.
    fred_start = datetime.date.fromisoformat(start_date)
    fred_end = datetime.date.fromisoformat(end_date)
    nber_q = web.DataReader('USRECQ', 'fred', fred_start, fred_end)
    nber_q.rename(columns={'USRECQ': 'NBER_RECESSION'}, inplace=True)
    print("...US NBER recession data fetched.", flush=True)

    gdp_q = web.DataReader('GDPC1', 'fred', fred_start, fred_end)
    # Quarter-over-quarter percent change (not annualised).
    gdp_q['GDP_GROWTH'] = gdp_q['GDPC1'].pct_change(periods=1) * 100
    print("...US Real GDP data fetched.", flush=True)
except Exception as e:
    print(f"*** FATAL ERROR: Could not fetch economic data: {e} ***", flush=True)
    # Exit non-zero so shells/CI see the failure. The bare `exit()` builtin
    # ended with status 0 and only exists when the `site` module is loaded.
    raise SystemExit(1) from e

# --- 2. Load Local Google Trends CSVs (US) ---
# Read one local Google Trends export per keyword ("<keyword>_USA.csv").
# Each export goes through the same steps:
#   1. Index it by its parsed 'Month' column (index named 'date').
#   2. Coerce the value column to numeric; unparseable cells become NaN.
#   3. Average to month-end and fill gaps with 0.
#   4. Min-max scale to [0, 1] with MinMaxScaler.
#   5. Outer-join into all_trends_df.
# Files that fail at any step are reported and skipped.
# All keywords are loaded here; basket selection happens in step 3.
print(f"\nLoading local {ECON_REGION} Google Trends CSVs...", flush=True)
scaler = MinMaxScaler()
KEYWORDS = ['F1', 'Golf', 'Tennis', 'Football', 'Cricket', 'Baseball', 'Basketball']
FILENAME_SUFFIX = "_USA.csv"
all_trends_df = pd.DataFrame()

for keyword in KEYWORDS:
    csv_path = keyword + FILENAME_SUFFIX
    print(f"Attempting to read: '{csv_path}'", flush=True)
    try:
        # Line 1 of the export is a 'Category:' banner; the real header follows it.
        raw = pd.read_csv(csv_path, skiprows=1)
        monthly = (
            raw.assign(date=pd.to_datetime(raw['Month']))
            .set_index('date')
            .drop(columns=['Month'])
        )
        value_col = monthly.columns[0]
        monthly[value_col] = pd.to_numeric(monthly[value_col], errors='coerce')
        monthly = monthly.resample('M').mean().rename(columns={value_col: keyword})
        filled = monthly[[keyword]].fillna(0)
        monthly[keyword] = scaler.fit_transform(filled)
        trend_col = monthly[[keyword]]
        if all_trends_df.empty:
            all_trends_df = trend_col
        else:
            all_trends_df = all_trends_df.join(trend_col, how='outer')
        print(f"--- SUCCESS: Loaded '{csv_path}' ---", flush=True)
    except Exception as e:
        print(f"--- FAILED: Error for '{csv_path}': {e}. Skipping. ---", flush=True)

# --- 3. Aggregate and Build Index (NEW BASKETS) ---
# Average the scaled keyword series into two basket indices (LX and MX),
# attach the monthly VIX / S&P 500 return, and form the MX/LX ratio (the GSSI).
print(f"\nConstructing the {ECON_REGION} GSSI (v2 Baskets)...", flush=True)
LX_KEYWORDS = ['F1', 'Golf', 'Tennis']
# US-specific MX basket: Cricket is loaded in step 2 but deliberately excluded here.
MX_KEYWORDS = ['Football', 'Baseball', 'Basketball']
lx_cols = [col for col in LX_KEYWORDS if col in all_trends_df.columns]
mx_cols = [col for col in MX_KEYWORDS if col in all_trends_df.columns]

if not lx_cols or not mx_cols:
    print("*** FATAL ERROR: One or both baskets are empty. ***", flush=True)
    print(f"LX Keywords found: {lx_cols}", flush=True)
    print(f"MX Keywords found: {mx_cols}", flush=True)
    # Non-zero exit status so the failure is visible to the caller.
    raise SystemExit(1)

print(f"LX Basket (final): {lx_cols}")
print(f"MX Basket (final): {mx_cols}")

# Row-wise mean skips NaN, so a keyword missing for some months does not blank the basket.
lx_index_w = all_trends_df[lx_cols].mean(axis=1)
mx_index_w = all_trends_df[mx_cols].mean(axis=1)
df_w = pd.DataFrame({'LX_Index': lx_index_w, 'MX_Index': mx_index_w})
df_w = df_w.join(econ_df_w[['VIX', 'SP500_Return']])
# Small offset keeps the ratio finite when the LX index hits 0 (its min-max floor).
epsilon = 0.01
df_w['MX_LX_Ratio'] = (df_w['MX_Index'] + epsilon) / (df_w['LX_Index'] + epsilon)

# --- 4. Resample, Merge, and Save ---
# Average the monthly GSSI/market frame to quarters, left-join the quarterly
# FRED series, drop incomplete quarters, and save the modelling dataset.
print("Resampling to quarterly and merging datasets...", flush=True)
df_q = df_w.resample('Q').mean()
# Round-trip every quarterly index through a PeriodIndex so the same calendar
# quarter maps to the same timestamp. Each source may stamp a quarter on a
# different day within it; after this step the joins below line up.
df_q.index = df_q.index.to_period('Q').to_timestamp('Q')
gdp_q.index = gdp_q.index.to_period('Q').to_timestamp('Q')
nber_q.index = nber_q.index.to_period('Q').to_timestamp('Q')
final_df = df_q.join(gdp_q).join(nber_q)
# forward-fill quarterly macro series where appropriate and drop any remaining NA rows
if 'NBER_RECESSION' in final_df.columns:
    final_df['NBER_RECESSION'] = final_df['NBER_RECESSION'].ffill()
if 'GDP_GROWTH' in final_df.columns:
    final_df['GDP_GROWTH'] = final_df['GDP_GROWTH'].ffill()
# .copy() detaches model_df from final_df. Later steps add columns to model_df,
# and without the copy pandas raises SettingWithCopyWarning on those assignments.
model_df = final_df.dropna().copy()
model_df.to_csv('GSSI_model_dataset_US_v2.csv')
print("...New dataset saved as 'GSSI_model_dataset_US_v2.csv'")

# --- 5. Create the 'Lead' Variable ---
# Target for the logit: the NBER recession flag of the next *calendar* quarter.
# The lookup is by quarter (PeriodIndex + 1), not by row position. A quarter
# removed by dropna() therefore can never be mistaken for "the next quarter".
# Rows whose following quarter is absent get NaN and are dropped.
_next_q_flag = model_df['NBER_RECESSION'].set_axis(model_df.index.to_period('Q'))
_next_q_flag = _next_q_flag.reindex(_next_q_flag.index + 1)
# assign() returns a new frame, so no chained-assignment warning is raised.
model_df = model_df.assign(Recession_Next_Q=_next_q_flag.to_numpy())
model_df = model_df.dropna(subset=['Recession_Next_Q'])

# --- 6. Descriptive Visual ---
# Plot the quarterly MX/LX ratio and shade NBER recession quarters, then save as PNG.
print("\nGenerating descriptive plot (v2 Baskets)...")
fig, ax = plt.subplots(figsize=(14, 7))
ax.plot(model_df.index, model_df['MX_LX_Ratio'], color='blue', label='MX/LX Ratio (GSSI_v2)')
ax.set_ylabel('GSSI (MX/LX Ratio)')
ax.set_xlabel('Year')
# The x-axis transform interprets y=0..1 as the bottom..top of the axes, so the
# shading spans the full height regardless of the ratio's data range.
in_recession = model_df['NBER_RECESSION'] == 1
ax.fill_between(
    model_df.index, 0, 1,
    where=in_recession,
    color='red', alpha=0.3,
    transform=ax.get_xaxis_transform(),
    label='NBER Recession',
)
ax.legend(loc='upper left')
ax.set_title('US GSSI (v2: No Cricket) vs. NBER Recessions')
ax.grid(True)
fig.savefig('GSSI_v2_vs_Recessions_Plot.png')
print("...Plot saved as 'GSSI_v2_vs_Recessions_Plot.png'.")

# --- 7. Model 1: The Full Logit Regression (v2) ---
# Regress next quarter's recession flag on this quarter's GSSI ratio, VIX and
# S&P 500 return (plus an intercept), and print the statsmodels summary.
print("\n--- Running Full Logit Model (v2 Baskets) ---")
LOGIT_PREDICTORS = ['MX_LX_Ratio', 'VIX', 'SP500_Return']
Y = model_df['Recession_Next_Q']
X = sm.add_constant(model_df[LOGIT_PREDICTORS])
logit_model = sm.Logit(Y, X).fit(disp=0)
print(logit_model.summary())

# --- 8. Model 2: Granger Causality (v2) ---
# Make the GSSI ratio stationary if needed (ADF at 5%), fit a VAR on GDP growth,
# the ratio and VIX with AIC lag selection, and F-test whether the ratio
# Granger-causes GDP growth.
print("\n--- Running Granger Causality Test (v2 Baskets) ---")
adf_test = adfuller(model_df['MX_LX_Ratio'].dropna())
print(f"ADF test p-value: {adf_test[1]:.4f}.")
# If ADF cannot reject a unit root at 5%, difference the ratio once.
if adf_test[1] > 0.05:
    ratio_stationary = model_df['MX_LX_Ratio'].diff()
else:
    ratio_stationary = model_df['MX_LX_Ratio']
# assign() returns a new frame. Setting a column directly on the filtered
# model_df can raise SettingWithCopyWarning.
model_df = model_df.assign(Ratio_stationary=ratio_stationary)
var_data = model_df[['GDP_GROWTH', 'Ratio_stationary', 'VIX']].dropna()
var_model = VAR(var_data)
var_results = var_model.fit(maxlags=4, ic='aic')
# AIC may pick a lag order of 0, which leaves no lagged terms to test for
# Granger causality. Fall back to a 1-lag VAR in that case.
if var_results.k_ar < 1:
    print("AIC selected 0 lags; refitting VAR with 1 lag for the Granger test.")
    var_results = var_model.fit(1)
test_result = var_results.test_causality('GDP_GROWTH', 'Ratio_stationary', kind='f')
print(f"Granger Test (H0: GSSI_v2 does not cause GDP Growth) P-value: {test_result.pvalue:.4f}")

print("\n--- ANALYSIS v2 COMPLETE ---")